import torch
import torch.nn as nn
import numpy as np
import pandas as pd

import smooth_dp_utils
import models
import data_utils
import utils
import create_networks
import data_generation

import argparse
import pickle

from tqdm import tqdm

import time


##################################################################################### 
#################################PARAMETERS########################################## 
##################################################################################### 


def parse_arguments():
    """Build the command-line parser for this training script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``dev``, ``N_EPOCHS``, ``seed_n``,
        ``beta``, ``lr``, ``N_batches``, ``bs_X``, ``ps_f``, ``load_model``,
        ``prefix``, ``Vs`` and ``path_data``.
    """
    # (flag, type, default, help) for every option, in declaration order.
    option_specs = (
        ('--dev', str, 'cpu', 'Device to use'),
        ('--N_EPOCHS', int, 20, 'N EPOCHS train'),
        ('--seed_n', int, 0, 'Seed number'),
        ('--beta', float, 5., 'Beta Smooth'),
        ('--lr', float, 0.0001, 'Learning Rate'),
        ('--N_batches', int, 100, 'N Batches in one Epoch'),
        ('--bs_X', int, 16, 'How many floyd warshalls in a batch'),
        ('--ps_f', float, 0.01, 'How many paths in one floyd warshall (factor)'),
        ('--load_model', int, 0, 'Load previous model?'),
        ('--prefix', str, 'softmax', 'Type of method'),
        ('--Vs', float, -1, 'Nr sampling nodes'),
        ('--path_data', str, './../data_exploration/', 'data path'),
    )

    cli = argparse.ArgumentParser(description='Set parameters for the program.')
    for flag, kind, default_value, help_text in option_specs:
        cli.add_argument(flag, type=kind, default=default_value, help=help_text)
    return cli.parse_args()


# Parse the command line once and expose every option as a module-level name
# used by the rest of the script.
args = parse_arguments()

dev, N_EPOCHS, seed_n = args.dev, args.N_EPOCHS, args.seed_n

beta_smooth, lr = args.beta, args.lr
N_batches, bs_X, ps_f = args.N_batches, args.bs_X, args.ps_f

load_model = args.load_model
prefix = args.prefix
path_data = args.path_data

# Early-stopping patience: training stops once this many consecutive epochs
# fail to improve the average batch loss.
epochs_wait = 4

print(f'RUNNING WITH {dev}')

##################################################################################### 
##################################################################################### 
#####################################################################################

# Load the per-trip features, the trip node sequences and the graph description.
df_features = pd.read_csv(f'{path_data}features_per_trip_useful.csv')
df_trips = pd.read_csv(f'{path_data}full_useful_trips.csv')
df_edges = pd.read_csv(f'{path_data}graph_clusters_0010_080.csv')
df_nodes = pd.read_csv(f'{path_data}nodes_clusters_0010_080.csv')
df_nodes['node_sorted'] = df_nodes['node_id_new']

# Train on a random 70% subset of the drivers (reproducible through seed_n).
unique_drivers = df_trips['driver'].drop_duplicates()
selected_drivers = unique_drivers.sample(frac=0.7, random_state=seed_n)
df_trips = df_trips[df_trips['driver'].isin(selected_drivers)]

# Keep only trips that visit more than one distinct node, ordered by driver, trip and time.
df_trips = df_trips[df_trips.groupby('trip_id_new').node_id.transform('nunique') > 1]
df_trips = df_trips.sort_values(by=['driver', 'trip_id_new', 'date_time'])

# Bucket day_of_Week: 0-3 -> 0, 4 -> 1, 5 -> 2, 6 -> 3
# (presumably Mon-Thu / Fri / Sat / Sun - confirm the source encoding).
df_features['day_of_Week'] = df_features['day_of_Week'].astype(int).map({
    0: 0, 1: 0, 2: 0, 3: 0, 4: 1,
    5: 2,
    6: 3})

df_features = pd.get_dummies(df_features, columns=['day_of_Week'])
# Min-max scale the start time to [0, 1].
df_features['time_start'] = (df_features['time_start'] - df_features['time_start'].min()) / (
    df_features['time_start'].max() - df_features['time_start'].min())

# One feature row per kept trip, in the order the trips appear in df_trips.
indices_trips = df_trips[['trip_id', 'driver', 'trip_id_new']].drop_duplicates()
df_features = indices_trips.merge(df_features, on=['trip_id', 'driver'], how='left')
# Cast the one-hot day columns by name: a positional iloc[:, -4:] would hit other
# columns if a day bucket is absent, and iloc assignment writes into the existing
# columns in place instead of replacing their dtype.
dow_cols = [c for c in df_features.columns if c.startswith('day_of_Week_')]
df_features[dow_cols] = df_features[dow_cols].astype(int)

feats = ['day_of_Week_0', 'day_of_Week_1', 'day_of_Week_2', 'day_of_Week_3', 'is_Holiday', 'time_start']
n_features = len(feats)

n_trips = len(df_features)

prior_M, edges_prior, M_indices = data_utils.get_prior_and_M_indices(df_nodes, df_edges)

trip_ids = df_trips.trip_id_new.unique()
# Features and trips must describe the same trips in the same order
# (np.array_equal also copes with arrays of different length).
if not np.array_equal(trip_ids, df_features.trip_id_new.unique()):
    raise ValueError('Trips in df_trips and df_features are not aligned')

# Number of graph nodes as a plain int. M_indices.max() is a 0-d tensor/array scalar;
# left as such, Vs = V would render as 'tensor(N)' inside the checkpoint paths.
V = int(M_indices.max()) + 1
X_np = np.array(df_features[feats])

# Node sequence of every trip, aligned with the rows of X_np. groupby sorts by
# trip_id_new, whereas X_np follows df_trips order (sorted by driver first),
# so reindex explicitly to the trip_ids order.
node_idx_sequence_trips = df_trips.groupby('trip_id_new')['node_id'].apply(list).reindex(trip_ids)
edges_seq_original = node_idx_sequence_trips.apply(lambda x: np.column_stack([x[:-1], x[1:]]))
start_nodes_original = node_idx_sequence_trips.apply(lambda x: x[0])
end_nodes_original = node_idx_sequence_trips.apply(lambda x: x[-1])

N_train = len(edges_seq_original)

# Column of every directed edge (u, v) in M_indices: one O(1) dict lookup per
# traversed edge instead of scanning all of M_indices each time.
edge_to_col = {(int(u), int(v)): col for col, (u, v) in enumerate(M_indices.tolist())}
if len(edge_to_col) != len(M_indices):
    raise ValueError('M_indices contains duplicated edges')

# Multi-hot targets: edges_idx_on_original[i, k] = 1 iff trip i traverses edge k.
# An edge absent from M_indices raises KeyError.
edges_idx_on_original = np.zeros((N_train, len(M_indices)), dtype=int)

print('Processing Data')
for i, trip_edges in enumerate(tqdm(edges_seq_original)):
    for u, v in trip_edges:
        edges_idx_on_original[i, edge_to_col[(int(u), int(v))]] = 1

edges_seq_original = list(edges_seq_original)
node_idx_sequence_trips = list(node_idx_sequence_trips)

# (N_train, 2) array holding [start node, end node] of every trip.
end_to_end_nodes_original = np.vstack((np.array(start_nodes_original), np.array(end_nodes_original))).T

edges_idx_on_original_tensor = torch.tensor(edges_idx_on_original, dtype=torch.float32)

# One-hot encodings of each trip's start node (sn_ohe) and end node (en_ohe).
trip_rows = torch.arange(N_train)
sn_ohe = torch.zeros((N_train, V))
en_ohe = torch.zeros((N_train, V))
sn_ohe[trip_rows, torch.from_numpy(end_to_end_nodes_original[:, 0].astype(np.int64))] = 1
en_ohe[trip_rows, torch.from_numpy(end_to_end_nodes_original[:, 1].astype(np.int64))] = 1

X = torch.tensor(X_np, dtype=torch.float32)

# Paths per Floyd-Warshall problem (a fraction ps_f of all trips, per --ps_f).
ps_in_batch = int(ps_f*X.shape[0])

# Node sub-sampling during training is enabled only for 0 < --Vs < V;
# otherwise the full graph is used (Vs = V).
Vs = int(args.Vs)
factor_s = 0.01
bool_scale = 0 < Vs < V
if not bool_scale:
    Vs = V

##################################################################################### 
######################### MODEL LOAD OR CREATE ###################################### 
#####################################################################################

print('----- Model Load or Create -----')

# The fcnn baseline additionally receives one-hot start and end nodes (2*V inputs).
inp_s_model = X.shape[-1]
if prefix == 'fcnn':
    inp_s_model = inp_s_model + 2*V

model = models.ANNVar(input_size=inp_s_model, output_size=len(M_indices), hl_sizes=[1024, 1024])
model = model.to(dev)

mse = nn.MSELoss(reduction='none')
bce = nn.BCELoss(reduction='none')

softmax = nn.Softmax(-1)

sigmo = nn.Sigmoid()

# Solver used by the 'dbcs' method to map predicted edge costs to predicted edges.
solver_spo = smooth_dp_utils.SolverDiff()
solver_spo.set_parameters(lambda_val=1000, prior_M=prior_M, M_indices=M_indices)

criterion = torch.nn.KLDivLoss(reduction='none')


def cross_entropy_cont(target, prediction):
    """Return KL(target || prediction) per row, summed over the last axis.

    ``prediction`` is shifted by 1e-5 before the log so zero probabilities do not
    produce -inf. Up to that shift, this differs from the plain cross-entropy
    ``-sum(target * log(prediction))`` only by the entropy of ``target``, which does
    not depend on the model.
    """
    return criterion(torch.log(prediction + 0.00001), target).sum(-1)


model_path = f'saved_models/cabspot_{prefix}_{Vs}_{seed_n}_{ps_f}.pkl'
model_path_inter = f'saved_models/cabspot_inter_{prefix}_{Vs}_{seed_n}_{ps_f}'

if load_model:
    try:
        # Load into a separate candidate and swap it in only on success, so that a
        # failed or partial load_state_dict never replaces the freshly built model.
        candidate = models.ANNVar(input_size=inp_s_model, output_size=len(M_indices), hl_sizes=[1024, 1024])
        candidate.load_state_dict(
            torch.load(model_path, map_location=torch.device(dev)))
    except Exception as err:
        # Best effort: continue from the fresh model, but report why loading failed.
        print('FAILED TO LOAD', err)
    else:
        model = candidate.to(dev)
        print('MODEL LOADED')
else:
    print('MODEL CREATED')

model = model.to(dev)
# NOTE(review): 10e-5 == 1e-4; confirm this weight decay is the intended value.
opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=10e-5)
print('MODEL ON ', next(model.parameters()).device)

print('----- Model Load or Create Finished -----')

##################################################################################### 
##################################################################################### 
#####################################################################################



# Methods handled by the loop below; fail fast on anything else instead of running
# (and checkpointing) epochs that train nothing.
VALID_PREFIXES = ('fcnn', 'dbcs', 'softmax')
if prefix not in VALID_PREFIXES:
    raise ValueError(f'Check prefix variable: got {prefix!r}, expected one of {VALID_PREFIXES}')

# Prior edge-cost spread (0.2 x prior edge weight) and the zero prior for dY used by `reg`.
prior_sigmaY = 0.2*prior_M[M_indices[:,0], M_indices[:,1]].to(dev)
prior_dY = torch.zeros_like(prior_sigmaY).to(dev)
prior_M = prior_M.to(dev)
M_indices = M_indices.to(dev)

elements, frequencies = utils.get_nodes_and_freqs(node_idx_sequence_trips)

not_best_count_accum = 0
loss_batch_avg_best = torch.inf

for epochs in range(0, N_EPOCHS):

    # Average batch loss of this epoch. Kept as a tensor so `.item()` below also
    # works when every batch was skipped.
    loss_batch_avg = torch.zeros((), device=dev)

    for batch in range(0, N_batches):

        if prefix == 'fcnn':
            # Baseline: predict each edge's on/off state from
            # [one-hot start, one-hot end, trip features] with a BCE loss.
            opt.zero_grad()
            idcs_batch = utils.generate_n_combinations(X, ps_in_batch-1, bs_X)
            start_time = time.time()

            X_batch = X[idcs_batch[:,0]].to(dev)
            sn_batch = sn_ohe[idcs_batch[:,0]].to(dev)
            en_batch = en_ohe[idcs_batch[:,0]].to(dev)

            X_batch = torch.hstack((sn_batch, en_batch, X_batch))

            OutNN_batch, _ = model(X_batch)
            edges_on_pred = sigmo(OutNN_batch)

            edges_on_true = edges_idx_on_original_tensor[idcs_batch[:,0]].to(dev)

            loss = bce(edges_on_pred, edges_on_true).sum(1).mean()

            loss.backward()
            opt.step()

            loss_batch_avg += (loss/N_batches).detach()

            continue

        if prefix == 'dbcs':
            # Predict edge costs and let the solver turn them into predicted edges
            # for each trip's (start, end) pair; BCE against the true edges.
            opt.zero_grad()
            idcs_batch = utils.generate_n_combinations(X, ps_in_batch-1, bs_X)
            start_time = time.time()

            X_batch = X[idcs_batch[:,0]].to(dev)

            dY = model(X_batch)[0]

            edges_on_true = edges_idx_on_original_tensor[idcs_batch[:,0]].to(dev)

            edges_on_pred = solver_spo.apply(
                    dY, end_to_end_nodes_original[idcs_batch[:,0]]
                )

            loss = bce(edges_on_pred, edges_on_true).sum(1).mean()

            loss.backward()
            opt.step()

            print('Batch', batch, round(loss.item(), 3),
            '\tTime: ', round(time.time() - start_time, 3))

            loss_batch_avg += (loss/N_batches).detach()

            continue

        # prefix == 'softmax' (validated above): smoothed Floyd-Warshall.
        if bool_scale:
            selected_indexes, selected_trips, nodes_selected, nodes_excluded = utils.selected_trips_and_idx(
                node_idx_sequence_trips, M_indices, elements, frequencies, Vs, V)
            # `is None`, not `== None`: the latter is elementwise (and ambiguous
            # in an `if`) when selected_indexes is an array or tensor.
            if selected_indexes is None:
                continue
            X_selected = X[selected_indexes]
        else:
            selected_trips = node_idx_sequence_trips
            X_selected = X

        idcs_batch = utils.generate_n_combinations(X_selected, ps_in_batch-1, bs_X)

        start_time = time.time()

        opt.zero_grad()

        X_batch = X_selected[idcs_batch[:,0]].to(dev)
        dY, dsigmaY = model(X_batch)
        sigmaY = dsigmaY + prior_sigmaY
        M_Y_pred = utils.costs_to_matrix(prior_M, M_indices, dY)
        M_sigmaY = utils.costs_to_matrix(0.0*prior_M, M_indices, sigmaY)

        if bool_scale:
            M_Y_pred_selected, M_sigmaY_selected, M_indices_selected_mapped = utils.select_Ms_from_selected_idx_and_trips(
                M_Y_pred, M_sigmaY, Vs, M_indices, nodes_excluded, nodes_selected, torch.tensor(beta_smooth), dev)
        else:
            M_Y_pred_selected = M_Y_pred
            M_sigmaY_selected = M_sigmaY
            M_indices_selected_mapped = M_indices

        # Randomly relabel the nodes every batch to remove any bias tied to node ordering.
        # shuffle_k_dict maps an original node id to its position after the permutation.
        k_nodes = torch.arange(Vs)
        k_nodes_shufled = k_nodes[torch.randperm(Vs)]
        shuffle_k_dict = {int(k_nodes_shufled[i]): int(k_nodes[i]) for i in range(Vs)}

        M_Y_pred_selected_shuf = M_Y_pred_selected[:,k_nodes_shufled][:, :, k_nodes_shufled]
        M_indices_selected_mapped_shuf = M_indices_selected_mapped.clone()
        for key, value in shuffle_k_dict.items():
            M_indices_selected_mapped_shuf[M_indices_selected_mapped == key] = value
        selected_trips_shuf = [[shuffle_k_dict[p] for p in sublist] for sublist in selected_trips]

        probs_pred = smooth_dp_utils.smooth_floyd_warshall_batch_adapted_parallel(
                        M_Y_pred_selected_shuf,
                        M_indices_selected_mapped_shuf,
                        dev,
                        torch.tensor(beta_smooth)
        )

        # Target distributions: mib summed over axis 1 and normalised over the last
        # axis. Rows with zero mass become 0/0 = NaN and are masked out below.
        mib = data_utils.get_m_inter_batch(selected_trips_shuf, idcs_batch, Vs, Vs)
        mib = torch.tensor(mib, dtype=torch.float32).to(dev)
        m_inter_total = mib.sum(1)/mib.sum(1).sum(-1).unsqueeze(-1)

        mask = ~torch.isnan(m_inter_total)
        true_paths_dist = m_inter_total[mask].reshape(-1, Vs)
        pred_paths_dist = probs_pred[mask].reshape(-1, Vs)
        loss_mse = cross_entropy_cont(true_paths_dist, pred_paths_dist).mean()

        # Small L2 pull of the predicted costs dY towards prior_dY.
        # Alternative (currently unused) regulariser that also involves sigmaY:
        # reg = (prior_sigmaY.exp().pow(-2) * (dY - prior_dY).pow(2) +
        #        (torch.log(sigmaY.clip(0.001)) - prior_sigmaY.log()).exp().pow(2) - 1 -
        #        2*(torch.log(sigmaY.clip(0.001)) - prior_sigmaY.log())).mean()
        reg = (dY - prior_dY).pow(2).mean()

        loss_total = loss_mse + 0.0000001*reg
        loss_total.backward()
        opt.step()

        print('Batch', batch, round(loss_mse.item(), 3), round(reg.item(), 3),
              '\tTime: ', round(time.time() - start_time, 3))

        loss_batch_avg += (loss_mse/N_batches).detach()

    # Early stopping on the epoch-average loss; each new best is also saved as an
    # epoch-tagged intermediate checkpoint.
    if loss_batch_avg >= loss_batch_avg_best:
        not_best_count_accum = not_best_count_accum + 1
        print('Did not improve results nr ', not_best_count_accum)
    else:
        _ = utils.check_or_create_folder("saved_models")
        torch.save(model.state_dict(), model_path_inter + f'_{epochs}.pkl')
        loss_batch_avg_best = loss_batch_avg
        not_best_count_accum = 0

    # The latest weights are always written to model_path, improved or not.
    _ = utils.check_or_create_folder("saved_models")
    torch.save(model.state_dict(), model_path)
    print('Batches AVG:', loss_batch_avg.item())

    if not_best_count_accum >= epochs_wait:
        print('Converged, exiting')
        break
